import torch
import numpy as np
from abc import abstractmethod
from scipy import special
from util import argmax_random_tiebreak

# RL ENVIRONMENT OBJECT ============================================================
class BinaryRewardEnv:
    """Stationary Bernoulli bandit environment.

    Pulling arm ``a`` pays 1 with probability ``success_p[a]`` and 0 otherwise,
    independently of the timestep.

    Args:
        success_p: per-arm success probabilities, indexable by arm.
        num_arms: number of arms (stored only; not validated against success_p).
        rng_gen: NumPy ``Generator`` used for draws. When omitted, a generator
            seeded with a fixed constant is created (and the seed is stored
            in ``self.seed``).
    """

    def __init__(self, success_p, num_arms=100, rng_gen=None):
        self.num_arms = num_arms
        self.success_p = success_p

        if rng_gen is not None:
            self.rng_gen = rng_gen
        else:
            self.seed = 49248204
            self.rng_gen = np.random.default_rng(self.seed)

    def generate_reward(self, arms, t):
        """Draw a Bernoulli reward for each entry of ``arms`` (``t`` is ignored)."""
        probs = self.success_p[arms]
        return self.rng_gen.binomial(1, probs)

    def get_expected_reward(self, arm, t):
        """Return the success probability of ``arm`` (``t`` is ignored)."""
        return self.success_p[arm]


class BinaryRewardEnv_impute:
    """Replay environment backed by precomputed tables.

    Rewards come from a fixed table of potential outcomes ``Y`` and expected
    rewards from ``success_p``; both are indexed as ``[arm, t]``.

    Args:
        success_p: table of expected rewards; must have at least ``T`` columns.
        T: horizon length (only used for the column-count check).
        X: per-step contexts or None; ``get_context`` slices it as ``X[:, [t], :]``.
        Y: table of realized outcomes, indexed ``[arm, t]``.
        num_arms: number of arms (stored only).
    """

    def __init__(self, success_p, T, X, Y, num_arms=100):
        assert success_p.shape[1] >= T

        self.X = X
        self.num_arms = num_arms
        self.success_p = success_p
        self.potential_outcomes = Y

    def get_context(self, t):
        """Return the contexts for step ``t`` (keeping the time axis), or None."""
        return None if self.X is None else self.X[:, [t], :]

    def generate_reward(self, arm, t):
        """Look up the realized outcome for ``arm`` at step ``t``."""
        outcomes = self.potential_outcomes
        return outcomes[arm, t]

    def get_expected_reward(self, arm, t):
        """Look up the expected reward for ``arm`` at step ``t``."""
        table = self.success_p
        return table[arm, t]


class BinaryRewardEnv_horizonDependent:
    """Bernoulli environment whose outcomes are pre-drawn for the whole horizon.

    At construction a ``num_arms x T`` table of Bernoulli(``success_p[arm]``)
    outcomes is sampled. Rewards are read from that table, and the "expected"
    reward is the realized table entry itself.

    Args:
        success_p: per-arm success probabilities.
        T: horizon (number of pre-drawn outcomes per arm).
        num_arms: number of arms / table rows.
        rng_gen: NumPy ``Generator``. A fixed-seed generator is created (and
            the seed stored in ``self.seed``) when omitted.
    """

    def __init__(self, success_p, T, num_arms=100, rng_gen=None):
        self.num_arms = num_arms
        self.success_p = success_p

        if rng_gen is not None:
            self.rng_gen = rng_gen
        else:
            self.seed = 49248204
            self.rng_gen = np.random.default_rng(self.seed)

        # One row of T outcomes per arm, drawn arm 0 first.
        rows = [
            self.rng_gen.binomial(1, self.success_p[a], size=(T))
            for a in range(num_arms)
        ]
        self.potential_outcomes = np.vstack(rows)

    def generate_reward(self, arm, t):
        """Return the pre-drawn outcome for ``arm`` at step ``t``."""
        arm_row = self.potential_outcomes[arm]
        return arm_row[t]

    def get_expected_reward(self, arm, t):
        """Return the realized outcome (same as ``generate_reward``)."""
        return self.generate_reward(arm, t)
        
# GET BANDIT ENVS ============================================================

def get_bandit_envs(num_arms, T, N_monte_carlo, success_p_all, seed=879437260, horizonDependent=False):
    """Build ``N_monte_carlo`` Bernoulli bandit environments.

    Each environment draws ``num_arms`` indices from ``success_p_all``. The
    draw uses ``Generator.choice`` with its default, i.e. *with*
    replacement, so an environment may contain repeated arms.

    The construction is deterministic given ``seed``. Because every
    environment consumes the shared generator in order, running with
    N_monte_carlo = N1 < N2 yields the same first N1 environments.

    Returns:
        list of ``(env, chosen_arms)`` tuples.
    """
    rng_env = np.random.default_rng(seed)
    arm_pool = np.arange(len(success_p_all))

    envs = []
    for _ in range(N_monte_carlo):
        chosen_arms = rng_env.choice(arm_pool, num_arms)
        arm_probs = success_p_all[chosen_arms]
        if horizonDependent:
            env = BinaryRewardEnv_horizonDependent(arm_probs, T, num_arms, rng_env)
        else:
            env = BinaryRewardEnv(arm_probs, num_arms, rng_env)
        envs.append((env, chosen_arms))
    return envs


def get_bandit_envs_from_data_dict(data_dict, num_arms, T, N_monte_carlo, seed=879437260, context=False):
    """Build ``N_monte_carlo`` replay environments by sub-sampling a data table.

    For each environment, ``num_arms`` distinct rows (arms) of
    ``data_dict['click_rate']`` are taken from a random permutation. For
    every chosen arm, ``T`` column indices are then drawn with replacement.
    The same column indices sub-sample ``data_dict['X']`` and
    ``data_dict['Y']`` when those keys are present.

    Args:
        data_dict: dict of tensors. ``'click_rate'`` is indexed (rows, cols).
            ``'X'``, if present, is indexed (rows, cols, x_dim).
            ``'Y'``, if present, is assumed to share ``click_rate``'s
            (rows, cols) layout — TODO confirm against the data loader.
        num_arms: number of arms per environment.
        T: horizon (columns sampled per arm).
        N_monte_carlo: number of environments to build.
        seed: seed of the torch generator that drives all sub-sampling.
        context: unused; kept for signature compatibility.

    Returns:
        list of ``(BinaryRewardEnv_impute, chosen_arms)`` with ``chosen_arms``
        a NumPy array of row indices. When ``data_dict`` has no ``'Y'``, the
        environment's outcome table is None.
    """
    generator = torch.Generator()
    generator.manual_seed(seed)

    click_rate = data_dict['click_rate']
    num_rows, num_cols = click_rate.shape[0], click_rate.shape[1]

    all_bandit_envs = []
    for _ in range(N_monte_carlo):
        # which arms/rows (without replacement)
        chosen_arms = torch.randperm(num_rows, generator=generator)[:num_arms]

        # which columns, after choosing arms/rows (with replacement)
        chosen_idxs = torch.randint(high=num_cols, size=(num_arms, T),
                                    dtype=torch.int64, generator=generator)

        subset_X = None
        if 'X' in data_dict:
            arm_X = data_dict['X'][chosen_arms]
            # gather's output takes the index's shape. Expanding over the
            # feature axis keeps every feature; a size-1 last axis would
            # keep only feature 0.
            x_index = chosen_idxs.unsqueeze(-1).expand(-1, -1, arm_X.shape[-1])
            subset_X = torch.gather(arm_X, dim=1, index=x_index)

        subset_Y = None
        if 'Y' in data_dict:
            subset_Y = torch.gather(data_dict['Y'][chosen_arms], dim=1, index=chosen_idxs)

        subset_click_rates = torch.gather(click_rate[chosen_arms], dim=1, index=chosen_idxs)
        # Y is a required positional argument of BinaryRewardEnv_impute, and
        # num_arms must match the sampled table rather than the default of 100.
        bandit_env_tmp = BinaryRewardEnv_impute(subset_click_rates, T, subset_X, subset_Y,
                                                num_arms=num_arms)
        all_bandit_envs.append((bandit_env_tmp, np.array(chosen_arms)))
    return all_bandit_envs

# Note that this returns tuples of (bandit_env, data_dict) rather than (bandit_env, arms)
def get_bandit_envs_from_dgp(dgp_fn, num_arms, T, N_monte_carlo, seed=879437260, context=False):
    """Build ``N_monte_carlo`` replay environments from a data-generating process.

    ``dgp_fn(num_arms, T, generator)`` is called once per environment with a
    shared, seeded torch generator. It must return a dict with keys
    ``'click_rate'``, ``'X'`` and ``'Y'``. ``context`` is unused.

    Returns:
        list of ``(BinaryRewardEnv_impute, data_dict)`` tuples.
    """
    generator = torch.Generator()
    generator.manual_seed(seed)

    def _one_env():
        data = dgp_fn(num_arms, T, generator)
        env = BinaryRewardEnv_impute(data['click_rate'], T, data['X'], data['Y'],
                                     num_arms=num_arms)
        return env, data

    return [_one_env() for _ in range(N_monte_carlo)]


# ABSTRACT BANDIT ALGORITHM CLASS ============================================================

class BanditAlgorithm:
    """Base class for every bandit policy in this module.

    Stores the number of arms and a seeded NumPy ``Generator`` (``rng_gen``).
    Subclasses override ``update_algorithm`` and ``sample_action``. The class
    does not use ``ABCMeta``, so the ``@abstractmethod`` markers are
    informational only and are not enforced.
    """

    def __init__(self, num_arms=100, seed=None):
        self.num_arms = num_arms
        self.seed = 457323948 if seed is None else seed
        self.rng_gen = np.random.default_rng(self.seed)

    @abstractmethod
    def update_algorithm(self, arms, rewards):
        """Incorporate the reward(s) observed for the pulled arm(s)."""

    @abstractmethod
    def sample_action(self):
        """Return the index of the next arm to pull."""

class MarginalAlg(BanditAlgorithm):
    """Greedy policy on the sequence model's marginal (history-free) predictions.

    Observed rewards are ignored. Every call scores all arms with
    ``seq_model.top_layer(Z_representation)`` and returns the arg-max.
    """

    def __init__(self, seq_model, Z_representation, num_arms):
        # Initialise the base class so num_arms / seed / rng_gen exist, as on
        # the other BanditAlgorithm subclasses.
        super(MarginalAlg, self).__init__(num_arms)
        self.seq_model = seq_model
        self.Z_representation = Z_representation
        if Z_representation is not None:
            assert len(Z_representation) == num_arms

    def update_algorithm(self, arm, reward):
        """No-op: marginal predictions do not depend on observed rewards."""
        pass

    def sample_action(self, return_extra=False):
        """Return the arm with the largest model output.

        With ``return_extra=True`` the raw per-arm outputs are returned too.
        """
        p_pred = self.seq_model.top_layer(self.Z_representation)
        best_arm = torch.argmax(p_pred).item()
        if return_extra:
            return best_arm, p_pred
        return best_arm

class MarginalAlgWithContext(BanditAlgorithm):
    """Greedy contextual policy on a model's history-free predictions.

    Observed rewards are ignored. Each step scores every arm with
    ``model(Z_representation, X)`` for the current context ``X`` and returns
    the arg-max.
    """

    def __init__(self, model, Z_representation, num_arms):
        # Initialise the base class so num_arms / seed / rng_gen exist, as on
        # the other BanditAlgorithm subclasses.
        super(MarginalAlgWithContext, self).__init__(num_arms)
        self.model = model
        self.Z_representation = Z_representation
        if Z_representation is not None:
            assert len(Z_representation) == num_arms

    def update_algorithm(self, arm, reward, X):
        """No-op: predictions do not depend on observed rewards."""
        pass

    def sample_action(self, X, return_extra=False):
        """Return the arm with the largest model output for context ``X``."""
        p_pred = self.model(self.Z_representation, X)
        best_arm = torch.argmax(p_pred).item()
        if return_extra:
            return best_arm, p_pred
        return best_arm

class LinearGaussianContextTS(BanditAlgorithm):
    """Contextual Thompson sampling with an independent Bayesian linear model per arm.

    Follows the linear-Gaussian model of https://arxiv.org/pdf/1802.09127:
    - each arm's coefficients have prior precision ``lam * I`` around
      ``prior_beta_mean``;
    - the posterior is recomputed from the arm's whole history after every
      update;
    - predictive variances are scaled by ``sig ** 2``.

    The full context table ``X`` (arms x timesteps x raw features) is given up
    front. At step ``self.t`` the stored features ``self.features[:, self.t, :]``
    are used, and ``sample_action`` asserts the caller's context matches them.
    """

    def add_const_if_necessary(self, X):
        """Append a constant-1 feature on the last axis when ``add_const_feature`` is set.

        Accepts 2-D ``(A, D)`` or 3-D ``(A, T, D)`` tensors. Any other rank
        raises ``ValueError``.
        """
        if not self.add_const_feature:
            return X
        if len(X.shape) == 2:
            A, D = X.shape
            return torch.concatenate([X, torch.ones(A, 1)], 1)
        elif len(X.shape) == 3:
            A, T, D = X.shape
            return torch.concatenate([X, torch.ones(A, T, 1)], 2)
        else:
            raise ValueError('something is wrong')

    def __init__(self, num_arms, X, hyparam_dict, seed=None, max_T=None, add_const_feature=True, prior_mean=None):
        """
        Args:
            num_arms: number of arms; must equal ``X.shape[0]``.
            X: context table of shape (num_arms, T, raw_Xdim).
            hyparam_dict: needs ``'lam'`` (prior precision) and ``'sig'``
                (scale applied to the predictive variance).
            seed: seed for the base-class NumPy generator.
            max_T: capacity of the per-arm history buffers (defaults to T).
            add_const_feature: append an intercept feature to X.
            prior_mean: prior mean of the intercept coefficient. It is only
                used when ``add_const_feature`` is True.
        """
        super(LinearGaussianContextTS, self).__init__(num_arms, seed)

        self.t = 0  # current timestep; advanced by each update_algorithm call

        A, T, raw_Xdim = X.shape
        assert A == num_arms
        self.num_arms = num_arms

        if max_T is None:
            max_T = T
        self.max_T = max_T
        assert max_T <= T

        self.add_const_feature = add_const_feature
        # add a constant to X, for all features
        self.features = self.add_const_if_necessary(X)
        self.Xdim = self.features.shape[-1]

        # LINEAR GAUSSIAN PARAMS (https://arxiv.org/pdf/1802.09127)
        self.lam = hyparam_dict['lam']  # lambda: prior precision Gamma_0 = I * lambda
        self.sig = hyparam_dict['sig']  # sigma

        # one per arm: model arms separately
        self.prior_beta_mean = torch.zeros(num_arms, self.Xdim)
        if prior_mean is not None and add_const_feature:
            # only the intercept (last, constant feature) gets a non-zero prior mean
            self.prior_beta_mean[:, -1] = prior_mean

        self.prior_beta_cov = torch.eye(self.Xdim).repeat(num_arms, 1, 1) / self.lam
        self.prior_inv_beta_cov = torch.eye(self.Xdim).repeat(num_arms, 1, 1) * self.lam

        # The current posterior starts at the prior. Previously it started at
        # (0, I), which ignored lam and prior_mean for arms not yet pulled.
        self.beta_mean = self.prior_beta_mean.clone()
        self.beta_cov = self.prior_beta_cov.clone()
        self.inv_beta_cov = self.prior_inv_beta_cov.clone()

        self.prev_Ys = torch.zeros(num_arms, max_T)
        self.obs_timesteps = torch.zeros(num_arms, max_T).int()  # timesteps of each arm's observations
        self.obs_count = torch.zeros(num_arms).int()  # number of observations per arm

    def update_algorithm(self, arm, reward, this_X):
        """Record ``reward`` for ``arm`` at step ``self.t`` and refresh that arm's posterior.

        ``this_X`` is unused: features are read from ``self.features`` at the
        current timestep. Raises ``ValueError`` when ``reward`` is neither an
        int nor convertible via ``.item()``.
        """
        if not isinstance(reward, int):
            try:
                reward = reward.item()
            except (AttributeError, RuntimeError, ValueError) as err:
                raise ValueError('reward must be int') from err
        self.prev_Ys[arm, self.obs_count[arm]] = reward
        self.obs_timesteps[arm, self.obs_count[arm]] = self.t
        self.obs_count[arm] += 1

        # Posterior for this arm from the prior and all of its history.
        n_obs = self.obs_count[arm]
        X = torch.index_select(self.features[arm], 0, self.obs_timesteps[arm, :n_obs])
        Y = self.prev_Ys[arm, :n_obs]

        XtX = X.T @ X
        post_cov = torch.inverse(self.prior_inv_beta_cov[arm] + XtX)
        post_mean = post_cov @ (
            self.prior_inv_beta_cov[arm] @ self.prior_beta_mean[arm].unsqueeze(1)
            + (X.T @ Y).unsqueeze(1)).flatten()

        self.beta_cov[arm] = post_cov
        self.beta_mean[arm] = post_mean
        self.t += 1

    def sample_action(self, X, return_extra=False):
        """Thompson-sample each arm's expected reward for the current context and return the arg-max.

        ``X`` is indexed as ``X[:, 0, :]``, i.e. the current context per arm.
        Sampling uses torch's global RNG. With ``return_extra=True`` the
        per-arm predictive means and variances are returned as well.
        """
        feat = self.add_const_if_necessary(X[:, 0, :])
        assert torch.allclose(self.features[:, self.t, :], feat)
        arm_means = torch.zeros(self.num_arms)
        arm_vars = torch.zeros(self.num_arms)
        for arm in range(self.num_arms):
            # feat[arm] is 1-D; vector @ matrix @ vector needs no transpose
            arm_vars[arm] = self.sig**2 * feat[arm] @ self.beta_cov[arm] @ feat[arm]
            arm_means[arm] = feat[arm] @ self.beta_mean[arm]
        samples = torch.normal(mean=arm_means, std=arm_vars**0.5)

        if return_extra:
            return samples.argmax().item(), {'arm_means': arm_means, 'arm_vars': arm_vars}
        return samples.argmax().item()



from torch.distributions import MultivariateNormal
class NeuralLinearArmsTogetherAlg(BanditAlgorithm):
    """Thompson sampling with one Bayesian linear model shared by all arms.

    Each arm is described by one row of ``features``, with a leading constant
    1 prepended. A single coefficient vector, with a standard-normal prior,
    is fit to every (arm, reward) pair seen so far, using Gaussian noise of
    variance ``noise_var``.
    """

    def __init__(self, num_arms, features, hyparam_dict, seed=None, max_T=10000):
        super(NeuralLinearArmsTogetherAlg, self).__init__(num_arms, seed)

        # prepend a constant (intercept) feature to every arm's feature row
        n_samples = len(features)
        phi_0 = torch.ones(n_samples)
        self.features = torch.concatenate([phi_0.unsqueeze(-1), features], 1)

        self.Xdim = self.features.shape[1]

        self.noise_var = hyparam_dict['noise_var']

        self.prior_beta_mean = torch.zeros(self.Xdim)
        self.prior_beta_cov = torch.eye(self.Xdim)
        self.prior_inv_beta_cov = torch.eye(self.Xdim)

        # current posterior (equal to the prior until the first update)
        self.beta_mean = torch.zeros(self.Xdim)
        self.beta_cov = torch.eye(self.Xdim)
        self.inv_beta_cov = torch.eye(self.Xdim)

        # fixed-capacity history buffers holding at most max_T observations
        self.prev_arms = torch.zeros(max_T).int()
        self.prev_Ys = torch.zeros(max_T)
        self.obs_count = 0

    def update_algorithm(self, arm, reward):
        """Append (arm, reward) to the history and recompute the shared posterior.

        Raises ``ValueError`` when ``reward`` is neither an int nor
        convertible via ``.item()``.
        """
        if not isinstance(reward, int):
            try:
                reward = reward.item()
            except (AttributeError, RuntimeError, ValueError) as err:
                raise ValueError('reward must be int') from err
        self.prev_Ys[self.obs_count] = reward
        self.prev_arms[self.obs_count] = arm
        self.obs_count += 1

        # posterior from the prior and the full history
        X = torch.index_select(self.features, 0, self.prev_arms[:self.obs_count])
        Y = self.prev_Ys[:self.obs_count]

        XtX = X.T @ X
        post_cov = torch.inverse(self.prior_inv_beta_cov + 1 / self.noise_var * XtX)
        post_mean = post_cov @ (
            self.prior_inv_beta_cov @ self.prior_beta_mean.unsqueeze(1)
            + (1 / self.noise_var * X.T @ Y).unsqueeze(1)).flatten()

        self.beta_cov = post_cov
        self.beta_mean = post_mean

    def sample_action(self):
        """Draw coefficients from the posterior (torch global RNG) and return the best arm."""
        beta_dist = MultivariateNormal(self.beta_mean, self.beta_cov)
        beta_sample = beta_dist.sample()
        return (self.features @ beta_sample).argmax().item()


class GaussianGaussianAlg(BanditAlgorithm):
    """Independent Gaussian-prior / Gaussian-likelihood Thompson sampling per arm.

    Arm ``a`` has prior ``N(marginal_preds[a], prior_var)``. Rewards are
    modelled as Gaussian with variance ``noise_var``. The conjugate posterior
    is updated in closed form, and sampling uses the base-class ``rng_gen``.

    ``hyparam_dict`` must contain either ``'prior_var'`` and ``'noise_var'``,
    or the aliases ``'sigma_squared'`` (prior variance) and ``'s_squared'``
    (noise variance). The aliases take precedence when present. The dict is
    not modified.
    """

    def __init__(self, marginal_preds, num_arms, hyparam_dict, seed=None):
        # Must name this class: super(NeuralLinearAlg, self) raises TypeError
        # because GaussianGaussianAlg does not derive from NeuralLinearAlg.
        super(GaussianGaussianAlg, self).__init__(num_arms, seed)
        self.prior_means = marginal_preds

        if 'sigma_squared' in hyparam_dict:
            prior_var = hyparam_dict['sigma_squared']
            noise_var = hyparam_dict['s_squared']
        else:
            prior_var = hyparam_dict['prior_var']
            noise_var = hyparam_dict['noise_var']
        self.prior_vars = np.ones(num_arms) * prior_var
        self.noise_var = noise_var

        assert len(marginal_preds) == num_arms

        self.post_means = np.copy(self.prior_means)  # start at the prior mean
        self.post_vars = np.copy(self.prior_vars)  # start at the prior variance

        self.num_obs = np.zeros(num_arms)
        self.reward_sum = np.zeros(num_arms)

    def update_algorithm(self, arm, reward):
        """Conjugate Gaussian update of ``arm``'s posterior with one reward."""
        self.num_obs[arm] += 1
        self.reward_sum[arm] += reward

        self.post_vars[arm] = 1 / (1 / self.prior_vars[arm] + self.num_obs[arm] / self.noise_var)
        self.post_means[arm] = self.post_vars[arm] * (
            self.prior_means[arm] / self.prior_vars[arm] + self.reward_sum[arm] / self.noise_var)

    def sample_action(self):
        """Draw one posterior sample per arm and return the index of the largest."""
        samples = self.rng_gen.normal(self.post_means, self.post_vars ** 0.5)
        return samples.argmax().item()




# POSTERIOR HALLUCINATION BANDIT ALGORITHMS ============================================================


class GreedyPosteriorMeanAlg(BanditAlgorithm):
    """Greedy policy on the sequence model's history-conditioned predictions.

    Keeps one model state per arm, advanced with every observed reward, and
    always pulls the arm whose ``seq_model.top_layer`` output (on
    ``[Z_representation, state]``, or on the state alone when Z is None) is
    largest.
    """

    def __init__(self, seq_model, Z_representation, num_arms):
        # Initialise the base class so num_arms / seed / rng_gen exist, as on
        # the other BanditAlgorithm subclasses.
        super(GreedyPosteriorMeanAlg, self).__init__(num_arms)
        self.seq_model = seq_model
        self.Z_representation = Z_representation
        self.curr_state = seq_model.init_model_states(batch_size=num_arms)
        if Z_representation is not None:
            assert len(Z_representation) == num_arms

    def update_algorithm(self, arm, reward):
        """Advance ``arm``'s model state with the observed ``reward``."""
        arm_curr_state = self.curr_state[[arm]]  # list index keeps a batch dim of 1
        arm_new_state = self.seq_model.update_state(arm_curr_state,
                                                    torch.tensor([reward]))
        self.curr_state[arm] = arm_new_state[0]

    def sample_action(self, return_extra=False):
        """Return the arm with the largest prediction (plus predictions if ``return_extra``)."""
        if self.Z_representation is not None:
            state = torch.cat([self.Z_representation, self.curr_state], 1)
        else:
            state = self.curr_state
        p_pred = self.seq_model.top_layer(state)
        best_arm = torch.argmax(p_pred).item()
        if return_extra:
            return best_arm, p_pred
        return best_arm


class SampledGreedyPosteriorMeanAlg(BanditAlgorithm):
    """Greedy policy on Monte-Carlo estimates of the sequence model's predictions.

    Like the greedy posterior-mean policy, but each arm's prediction ``p`` is
    replaced by the mean of ``num_samples`` Bernoulli(p) draws, taken with
    torch's global RNG, before the arg-max.
    """

    def __init__(self, seq_model, Z_representation, num_arms, num_samples):
        # Initialise the base class so num_arms / seed / rng_gen exist, as on
        # the other BanditAlgorithm subclasses.
        super(SampledGreedyPosteriorMeanAlg, self).__init__(num_arms)
        self.seq_model = seq_model
        self.num_samples = num_samples
        self.Z_representation = Z_representation
        self.curr_state = seq_model.init_model_states(batch_size=num_arms)
        if Z_representation is not None:
            assert len(Z_representation) == num_arms

    def update_algorithm(self, arm, reward):
        """Advance ``arm``'s model state with the observed ``reward``."""
        arm_curr_state = self.curr_state[[arm]]  # list index keeps a batch dim of 1
        arm_new_state = self.seq_model.update_state(arm_curr_state,
                                                    torch.tensor([reward]))
        self.curr_state[arm] = arm_new_state[0]

    def sample_action(self, return_extra=False):
        """Return the arm with the largest sampled-mean prediction."""
        if self.Z_representation is not None:
            state = torch.cat([self.Z_representation, self.curr_state], 1)
        else:
            state = self.curr_state
        p_pred = self.seq_model.top_layer(state)
        # one row per arm, num_samples Bernoulli draws per row, averaged
        sampled_pred = torch.bernoulli(p_pred.flatten().unsqueeze(-1).repeat(1, self.num_samples)).mean(1)
        best_arm = torch.argmax(sampled_pred).item()
        if return_extra:
            return best_arm, sampled_pred
        return best_arm

class GreedySequentialWithContext(BanditAlgorithm):
    """Greedy contextual policy using an imputation sequence model.

    Every observed (context, reward) is stored in per-arm history buffers.
    Each step:
    1. the model summarises the history via ``model.get_state``;
    2. ``model.top_layer`` scores ``[Z, X, state]`` for the current context ``X``;
    3. the arm with the largest score is chosen.
    """

    def __init__(self, model, Z, num_arms, T, X):
        # Initialise the base class so num_arms / seed / rng_gen exist, as on
        # the other BanditAlgorithm subclasses.
        super(GreedySequentialWithContext, self).__init__(num_arms)

        # imputation sequential model with context
        self.model = model
        self.Z = Z
        self.T = T
        self.num_arms = num_arms
        assert Z.shape[0] == self.num_arms
        # stored but not read by this class
        self.X = X
        self.Xdim = X.shape[-1]

        # step counter, advanced by each update_algorithm call
        self.t = 0
        # observations seen by the bandit algorithm, indexed [arm, t]
        self.hist_X = torch.zeros(self.num_arms, self.T, self.Xdim)
        self.hist_Y = torch.zeros(self.num_arms, self.T)
        self.hist_mask = torch.zeros(self.num_arms, self.T)

    def update_algorithm(self, arm, reward, X):
        """Store ``arm``'s context ``X[arm]`` and ``reward`` at step ``self.t``."""
        # NOTE(review): assumes X[arm] broadcasts into a (1, Xdim) slot — TODO check shapes
        self.hist_X[arm, [self.t], :] = X[arm]
        self.hist_Y[arm, self.t] = reward
        self.hist_mask[arm, self.t] = 1
        self.t += 1

    def sample_action(self, X, return_extra=False):
        """Score every arm for context ``X`` given the history and return the best arm."""
        curr_state = self.model.get_state(self.hist_X, self.hist_Y)
        input_ = torch.cat([self.Z.unsqueeze(1), X, curr_state.unsqueeze(1)], 2)
        p_hat_pred = self.model.top_layer(input_).squeeze(2)
        best_arm = torch.argmax(p_hat_pred).item()
        if return_extra:
            return best_arm, p_hat_pred
        return best_arm
    
class PosteriorHallucinationAlg(BanditAlgorithm):
    """Posterior-sampling policy that draws from the sequence model by imagining outcomes.

    Keeps one model state per arm, advanced with every observed reward. Each
    step asks ``seq_model.get_posterior_draws`` for one posterior draw per arm
    (``num_imagined`` imagined outcomes each) and pulls the arg-max. Ties are
    broken at random via ``argmax_random_tiebreak`` when
    ``randomly_break_ties`` is set.
    """

    def __init__(self, seq_model, Z_representation, num_arms,
                 num_imagined=100, seed=None, randomly_break_ties=False):
        super(PosteriorHallucinationAlg, self).__init__(num_arms, seed)
        self.seq_model = seq_model
        self.Z_representation = Z_representation
        self.curr_state = seq_model.init_model_states(batch_size=num_arms)
        self.num_imagined = num_imagined
        self.randomly_break_ties = randomly_break_ties
        if Z_representation is not None:
            assert len(Z_representation) == num_arms

    def update_algorithm(self, arm, reward):
        """Advance ``arm``'s model state with the observed ``reward``."""
        updated = self.seq_model.update_state(self.curr_state[[arm]],
                                              torch.tensor([reward]))
        self.curr_state[arm] = updated[0]

    def sample_action(self, return_extra=False):
        """Return the arm with the largest posterior draw (plus the draws if ``return_extra``)."""
        post_draws = self.seq_model.get_posterior_draws(
            Z_input=self.Z_representation,
            curr_state=self.curr_state,
            num_imagined=self.num_imagined,
            num_repetitions=1,
        )
        if self.randomly_break_ties:
            winner = argmax_random_tiebreak(post_draws)
        else:
            winner = torch.argmax(post_draws)
        chosen_arm = winner.item()
        if not return_extra:
            return chosen_arm
        return chosen_arm, post_draws


class SequentialAlgWithContext(BanditAlgorithm):
    """Imputation-based contextual bandit for sequential models with context.

    Each step:
    1. The sequence model fills in ("imputes") a table of outcomes for every
       arm on a set of contexts, given the history observed so far.
    2. A per-arm table-to-policy (TTP) predictor is fit on that table.
    3. The predictor is evaluated on the current context ``X[0]``, and the arm
       with the largest prediction is chosen.

    The TTP predictor is picked in this order:
      * ``ignore_context``: the per-arm mean of the imputed outcomes;
      * ``simple_logistic``: an sklearn ``LogisticRegression`` per arm;
      * ``simple_xgb``: an sklearn ``GradientBoostingClassifier`` per arm;
      * otherwise: a model from ``get_ttp`` trained with ``train_ttp``.
    """

    def __init__(self, model, Z, num_arms, T, X,
                 get_ttp=None, train_ttp=None,
                 simple_logistic=True,
                 simple_xgb=False,
                 ignore_context=False,
                 num_imagined=500,
                 finite_horizon_alg=False,
                 no_shuffle_boot=False):

        assert simple_logistic or ((get_ttp is not None) and (train_ttp is not None)) or simple_xgb
        assert not (simple_logistic and simple_xgb)
        # Initialise the base class so num_arms / seed / rng_gen exist, as on
        # the other BanditAlgorithm subclasses.
        super(SequentialAlgWithContext, self).__init__(num_arms)

        # imputation sequential model with context
        self.model = model
        self.Z = Z
        self.T = T
        self.num_arms = num_arms
        assert Z.shape[0] == self.num_arms
        self.num_imagined = num_imagined
        self.finite_horizon_alg = finite_horizon_alg
        self.no_shuffle_boot = no_shuffle_boot

        self.X = X
        self.eval_X = X  # contexts the table imputation is done over (could differ from X)

        # ttp = table to policy
        self.ignore_context = ignore_context
        self.simple_logistic = simple_logistic
        self.simple_xgb = simple_xgb

        self.get_ttp = get_ttp
        self.train_ttp = train_ttp
        self.Xdim = X.shape[-1]
        self.Zdim = Z.shape[-1]

        # step counter, advanced by each update_algorithm call
        self.t = 0
        # observations seen by the bandit algorithm, indexed [arm, t]
        self.hist_X = torch.zeros(self.num_arms, self.T, self.Xdim)
        self.hist_Y = torch.zeros(self.num_arms, self.T)
        self.hist_mask = torch.zeros(self.num_arms, self.T)

    def update_algorithm(self, arm, reward, X):
        """Store ``arm``'s context ``X[arm]`` and ``reward`` at step ``self.t``."""
        # NOTE(review): assumes X[arm] broadcasts into a (1, Xdim) slot — TODO check shapes
        self.hist_X[arm, [self.t], :] = X[arm]
        self.hist_Y[arm, self.t] = reward
        self.hist_mask[arm, self.t] = 1
        self.t += 1

    def _imputed_table(self):
        """Return ``(ttp_inputs, ttp_labels)``: contexts and model-imputed outcomes per arm."""
        if self.finite_horizon_alg:
            ttp_inputs = self.eval_X
            ttp_labels = self.model.fill_table_naive_finite(
                self.Z, self.hist_X, self.hist_Y, self.hist_mask, self.eval_X).detach()
            return ttp_inputs, ttp_labels

        # subsample num_imagined contexts (bootstrap, or in order when no_shuffle_boot)
        if self.no_shuffle_boot:
            assert self.eval_X.shape[1] == self.num_imagined
            idxs = torch.arange(self.eval_X.shape[1])
        else:
            idxs = torch.randint(0, self.eval_X.shape[1], (self.num_imagined,))
        # The index has size 1 on dim 0, so contexts come from eval_X[0] and
        # are then repeated for every arm.
        gather_index = idxs.unsqueeze(0).unsqueeze(-1).repeat((1, 1, self.Xdim))
        ttp_inputs = torch.gather(self.eval_X, 1, gather_index).repeat(self.eval_X.shape[0], 1, 1)
        ttp_labels = self.model.fill_table_naive(
            self.Z, self.hist_X, self.hist_Y, self.hist_mask, ttp_inputs).detach()
        return ttp_inputs, ttp_labels

    def _fit_sklearn_per_arm(self, model_cls, ttp_inputs, ttp_labels, X):
        """Fit one sklearn classifier per arm and predict P(y=1) on ``X[0]``.

        Arms whose imputed labels are all identical skip fitting and predict
        that constant label.
        """
        preds = []
        for idx in range(self.num_arms):
            labels = ttp_labels[idx]
            if labels.min() == labels.max():
                preds.append(torch.ones(len(X[0])) * labels.max())
            else:
                ttp_model = model_cls()
                ttp_model.fit(ttp_inputs[idx].numpy(), labels.numpy())
                preds.append(ttp_model.predict_proba(X[0])[:, 1])
        return preds

    def _fit_custom_ttp_per_arm(self, ttp_inputs, ttp_labels, X):
        """Build, train, and evaluate one user-supplied TTP model per arm on ``X[0]``."""
        ttps_per_arm = [self.get_ttp(self, in_dim=self.Xdim) for _ in range(self.num_arms)]

        # train
        for idx in range(self.num_arms):
            ttp_model, ttp_criterion, ttp_optimizer = ttps_per_arm[idx]
            trained = self.train_ttp(ttp_model, ttp_criterion, ttp_optimizer,
                                     ttp_inputs[idx], ttp_labels[idx])
            # Keep the model train_ttp returns; previously it was discarded and
            # evaluation used the pre-training object. A None return is treated
            # as in-place training.
            if trained is not None:
                ttps_per_arm[idx] = (trained, ttp_criterion, ttp_optimizer)

        # eval
        preds = []
        for ttp_model, _, _ in ttps_per_arm:
            ttp_model.eval()
            preds.append(ttp_model(X[0]).detach().item())
        return preds

    def sample_action(self, X, return_extra=False):
        """Return the arm with the highest TTP prediction for the current context ``X``."""
        ttp_inputs, ttp_labels = self._imputed_table()

        if self.ignore_context:
            preds = ttp_labels.mean(1)
        elif self.simple_logistic:
            from sklearn.linear_model import LogisticRegression
            preds = self._fit_sklearn_per_arm(LogisticRegression, ttp_inputs, ttp_labels, X)
        elif self.simple_xgb:
            from sklearn.ensemble import GradientBoostingClassifier
            preds = self._fit_sklearn_per_arm(GradientBoostingClassifier, ttp_inputs, ttp_labels, X)
        else:
            preds = self._fit_custom_ttp_per_arm(ttp_inputs, ttp_labels, X)

        p_pred = torch.tensor(np.array(preds))
        best_arm = torch.argmax(p_pred).item()
        if return_extra:
            return best_arm, p_pred
        return best_arm


class PosteriorHallucinationAlg_horizonDependent(BanditAlgorithm):
    """Horizon-aware posterior hallucination policy.

    Keeps one model state per arm plus the raw list of rewards seen on each
    arm. Both are passed to
    ``seq_model.get_posterior_draws_horizonDependent`` together with the
    horizon ``T``, and the arm with the largest draw is pulled.
    """

    def __init__(self, seq_model, Z_representation, num_arms, T, seed=None):
        super(PosteriorHallucinationAlg_horizonDependent, self).__init__(num_arms, seed)
        self.seq_model = seq_model
        self.Z_representation = Z_representation
        self.curr_state = seq_model.init_model_states(batch_size=num_arms)
        self.T = T
        self.past_rewards = {a: [] for a in range(num_arms)}
        if Z_representation is not None:
            assert len(Z_representation) == num_arms

    def update_algorithm(self, arm, reward):
        """Advance ``arm``'s model state and append ``reward`` to its history."""
        updated = self.seq_model.update_state(self.curr_state[[arm]],
                                              torch.tensor([reward]))
        self.curr_state[arm] = updated[0]
        self.past_rewards[arm].append(reward)

    def sample_action(self, return_extra=False):
        """Return the arm with the largest posterior draw (plus the draws if ``return_extra``)."""
        history = [torch.Tensor(self.past_rewards[a]) for a in range(self.num_arms)]

        post_draws = self.seq_model.get_posterior_draws_horizonDependent(
            Z_input=self.Z_representation,
            curr_state=self.curr_state,
            T=self.T,
            past_obs=history,
            num_repetitions=1,
        )
        chosen = torch.argmax(post_draws).item()
        if not return_extra:
            return chosen
        return chosen, post_draws


# SQUARE-CB BANDIT ALGORITHMS ============================================================

# UNTESTED AND UNUSED, maybe remove
class SquareCB(BanditAlgorithm):
    """SquareCB (inverse-gap-weighting) exploration on top of a sequence model.

    Follows https://arxiv.org/pdf/2002.04926.pdf. The learning-rate schedule
    ``gamma_t = gamma0 * t ** rho`` follows https://arxiv.org/pdf/2010.03104,
    not https://proceedings.mlr.press/v80/foster18a/foster18a.pdf.
    The original author marked this class as untested and unused.
    """

    def __init__(self, seq_model, Z_representation, num_arms, T, seed=None, hyparam_dict=None):
        super(SquareCB, self).__init__(num_arms, seed)
        self.mu = num_arms  # mu = number of arms: the case the regret bound is proved for
        self.seq_model = seq_model
        self.Z_representation = Z_representation
        self.curr_state = seq_model.init_model_states(batch_size=num_arms)
        self.T = T
        self.t = 0.0  # number of updates so far

        # learning-rate schedule parameters (https://arxiv.org/pdf/2010.03104)
        self.gamma0 = hyparam_dict['gamma0']
        self.rho = hyparam_dict['rho']
        if Z_representation is not None:
            assert len(Z_representation) == num_arms

    def update_algorithm(self, arm, reward):
        """Advance ``arm``'s model state with ``reward`` and bump the step counter."""
        updated = self.seq_model.update_state(self.curr_state[[arm]],
                                              torch.tensor([reward]))
        self.curr_state[arm] = updated[0]
        self.t += 1

    def sample_action(self):
        """Sample an arm from the SquareCB inverse-gap-weighted distribution.

        Non-best arms get probability ``1 / (mu + lr * gap)``, and the best arm
        gets the remainder. With ``t == 0`` and ``rho > 0`` this gives
        ``lr == 0``, i.e. a uniform distribution.
        """
        self.lr = self.gamma0 * self.t ** self.rho
        with torch.no_grad():
            if self.Z_representation is None:
                state = self.curr_state
            else:
                state = torch.cat([self.Z_representation, self.curr_state], 1)
            p_hats = self.seq_model.top_layer(state).squeeze().numpy()

            best = np.argmax(p_hats)
            gaps = p_hats[best] - p_hats
            probs = 1 / (self.mu + self.lr * gaps)
            others = np.arange(self.num_arms) != best
            probs[best] = 1 - np.dot(probs.squeeze(), others)

            assert np.isclose(np.sum(probs), 1)

        draw = self.rng_gen.multinomial(1, probs)
        return np.argmax(draw)


# NEURAL LINEAR BANDIT ALGORITHMS ============================================================

class NeuralLinearAlg(BanditAlgorithm):
    """Independent Gaussian-Gaussian Thompson sampling per arm, centred on model predictions.

    Arm ``a`` has prior ``N(marginal_preds[a], prior_var)``. Rewards are
    modelled as Gaussian with variance ``noise_var``. The conjugate posterior
    is updated in closed form, and sampling uses the base-class ``rng_gen``.

    ``hyparam_dict`` must contain either ``'prior_var'`` and ``'noise_var'``,
    or the aliases ``'sigma_squared'`` (prior variance) and ``'s_squared'``
    (noise variance). The aliases take precedence when present. The caller's
    dict is no longer modified.
    """

    def __init__(self, marginal_preds, num_arms, hyparam_dict, seed=None):
        super(NeuralLinearAlg, self).__init__(num_arms, seed)
        self.prior_means = marginal_preds

        if 'sigma_squared' in hyparam_dict:
            prior_var = hyparam_dict['sigma_squared']
            noise_var = hyparam_dict['s_squared']
        else:
            prior_var = hyparam_dict['prior_var']
            noise_var = hyparam_dict['noise_var']
        self.prior_vars = np.ones(num_arms) * prior_var
        self.noise_var = noise_var

        assert len(marginal_preds) == num_arms

        self.post_means = np.copy(self.prior_means)  # start at the prior mean
        self.post_vars = np.copy(self.prior_vars)  # start at the prior variance

        self.num_obs = np.zeros(num_arms)
        self.reward_sum = np.zeros(num_arms)

    def update_algorithm(self, arm, reward):
        """Conjugate Gaussian update of ``arm``'s posterior with one reward."""
        self.num_obs[arm] += 1
        self.reward_sum[arm] += reward

        self.post_vars[arm] = 1 / (1 / self.prior_vars[arm] + self.num_obs[arm] / self.noise_var)
        self.post_means[arm] = self.post_vars[arm] * (
            self.prior_means[arm] / self.prior_vars[arm] + self.reward_sum[arm] / self.noise_var)

    def sample_action(self):
        """Draw one posterior sample per arm and return the index of the largest."""
        samples = self.rng_gen.normal(self.post_means, self.post_vars ** 0.5)
        return samples.argmax().item()


# BETA BERNOULLI BANDIT ALGORITHMS ============================================================
from torch.distributions import Beta

class BetaBernoulliAlg(BanditAlgorithm):
    """Beta-Bernoulli Thompson sampling with a shared Beta(alpha, beta) prior on every arm.

    Sampling uses torch's global RNG, not the base-class ``rng_gen``.
    """

    def __init__(self, num_arms, hyparam_dict, seed=None):
        super(BetaBernoulliAlg, self).__init__(num_arms, seed)
        self.ones = torch.zeros(num_arms)  # per-arm count of observed 1-rewards
        self.zeros = torch.zeros(num_arms)  # per-arm count of observed 0-rewards
        self.alpha = hyparam_dict['alpha']
        self.beta = hyparam_dict['beta']

    def update_algorithm(self, arm, reward):
        """Count a binary ``reward`` for ``arm``; raises ``ValueError`` if it is not 0/1.

        Python ints and floats are used directly. Other types (numpy scalars,
        0-d tensors) are unwrapped with ``.item()``.
        """
        if not isinstance(reward, (int, float)):
            reward = reward.item()
        if reward == 1:
            self.ones[arm] += 1
        elif reward == 0:
            self.zeros[arm] += 1
        else:
            raise ValueError('Reward must be binary for beta bernoulli bandit')

    def sample_action(self):
        """Sample each arm's Beta posterior and return the arg-max as a Python int."""
        m = Beta(self.alpha + self.ones, self.beta + self.zeros)
        samples = m.sample()
        # .item() so the arm index is an int, like the other algorithms here
        return samples.argmax().item()

class BetaBernoulliMixtureAlg(BanditAlgorithm):
    """Thompson sampling with a two-component Beta-mixture prior on every arm.

    The prior per arm is
    ``mixweight * Beta(alpha1, beta1) + (1 - mixweight) * Beta(alpha2, beta2)``.
    After counting successes and failures, the posterior is again a
    two-component Beta mixture whose weights are re-scaled by each
    component's marginal likelihood. Sampling uses torch's global RNG.
    """

    def __init__(self, num_arms, hyparam_dict, seed=None):
        super(BetaBernoulliMixtureAlg, self).__init__(num_arms, seed)
        self.ones = torch.zeros(num_arms)  # per-arm count of observed 1-rewards
        self.zeros = torch.zeros(num_arms)  # per-arm count of observed 0-rewards
        self.alpha1 = hyparam_dict['alpha1']
        self.alpha2 = hyparam_dict['alpha2']
        self.beta1 = hyparam_dict['beta1']
        self.beta2 = hyparam_dict['beta2']
        self.mixweight = hyparam_dict['mixweight']

    def update_algorithm(self, arm, reward):
        """Count a binary ``reward`` for ``arm``; raises ``ValueError`` if it is not 0/1.

        Python ints and floats are used directly. Other types (numpy scalars,
        0-d tensors) are unwrapped with ``.item()``.
        """
        if not isinstance(reward, (int, float)):
            reward = reward.item()
        if reward == 1:
            self.ones[arm] += 1
        elif reward == 0:
            self.zeros[arm] += 1
        else:
            raise ValueError('Reward must be binary for beta bernoulli bandit')

    def sample_action(self):
        """Sample each arm's mixture posterior and return the arg-max as a Python int."""
        post_alpha1 = self.alpha1 + self.ones
        post_alpha2 = self.alpha2 + self.ones
        post_beta1 = self.beta1 + self.zeros
        post_beta2 = self.beta2 + self.zeros
        # Log marginal-likelihood ratio log[B(post) / B(prior)] per component; formula on page 24 of
        # https://www2.stat.duke.edu/~rcs46/modern_bayes17/lecturesModernBayes17/lecture-1/01-intro-to-Bayes.pdf
        # torch.as_tensor ensures torch.exp gets a tensor whatever type scipy returns.
        clog1 = torch.as_tensor(special.betaln(post_alpha1, post_beta1) - special.betaln(self.alpha1, self.beta1))
        clog2 = torch.as_tensor(special.betaln(post_alpha2, post_beta2) - special.betaln(self.alpha2, self.beta2))
        # Shift both exponents by -clog1 so the first is exactly 0; the
        # shift cancels in the ratio below.
        bigconst = -clog1
        c1 = self.mixweight * torch.exp(clog1 + bigconst)
        c2 = (1 - self.mixweight) * torch.exp(clog2 + bigconst)
        post_mixweight = c1 / (c1 + c2)

        beta1 = torch.distributions.beta.Beta(post_alpha1, post_beta1)
        beta2 = torch.distributions.beta.Beta(post_alpha2, post_beta2)
        samples1 = beta1.sample()
        samples2 = beta2.sample()
        # per-arm component choice: 1 -> component 1, 0 -> component 2
        ind = torch.bernoulli(post_mixweight * torch.ones_like(post_mixweight))
        post_samples = samples1 * ind + samples2 * (1 - ind)
        # .item() so the arm index is an int, like the other algorithms here
        return post_samples.argmax().item()


class UCBAlg(BanditAlgorithm):
    """Upper-confidence-bound policy for independent arms.

    The confidence width follows
    https://papers.nips.cc/paper_files/paper/2011/file/e1d5be1c7f2f456670de3d53c7b54f4a-Paper.pdf
    in one dimension, with sub-Gaussian parameter 0.5 (suitable for rewards
    in [0, 1]). Every arm is pulled once, lowest index first, before the
    confidence bounds are used.

    Args:
        num_arms: number of arms.
        delta: failure probability used in the confidence width.
        seed: forwarded to ``BanditAlgorithm.__init__``.
    """

    def __init__(self, num_arms, delta=0.1, seed=None):
        super(UCBAlg, self).__init__(num_arms, seed)
        self.arm_means = np.zeros(num_arms)   # running mean reward per arm
        self.arm_counts = np.zeros(num_arms)  # number of pulls per arm
        self.delta = delta
        self.subgaussian_sigma = 0.5

    def update_algorithm(self, arm, reward):
        """Fold ``reward`` into the running mean of ``arm``."""
        count = self.arm_counts[arm]
        denom = count + 1
        total = self.arm_means[arm] * count + reward
        self.arm_means[arm] = total / denom
        self.arm_counts[arm] += 1

    def compute_ucb(self):
        """Return the confidence width of every arm (undefined for unpulled arms)."""
        n = self.arm_counts
        log_term = np.log(self.num_arms * np.sqrt(n + 1) / self.delta)
        width_sq = (n + 1) / n ** 2 * (1 + 2 * log_term)
        return self.subgaussian_sigma * np.sqrt(width_sq)

    def sample_action(self):
        """Return the first never-pulled arm, else the arm with the highest UCB."""
        unpulled = self.arm_counts == 0
        if unpulled.any():
            return np.argmax(unpulled)
        return np.argmax(self.arm_means + self.compute_ucb())



class LinUCBAlg(BanditAlgorithm):
    """Disjoint linear UCB: one ridge-regression model per arm.

    For each arm a the class keeps V_a = I + sum x x^T (``arm_invcov``) and
    b_a = sum reward * x (``arm_sums``) over the contexts x at which the arm
    was pulled; contexts get a leading intercept feature equal to 1.
    The score of arm a for context x is
        x . theta_a + beta_a * sqrt(x^T V_a^{-1} x),
    with theta_a = V_a^{-1} b_a and
        beta_a = 1 + sigma * sqrt(2 * log(sqrt(det V_a) / delta)),
    the confidence radius of the OFUL paper with regularizer 1 and
    parameter-norm bound 1.

    References:
        https://arxiv.org/pdf/1003.0146
        https://papers.nips.cc/paper_files/paper/2011/file/e1d5be1c7f2f456670de3d53c7b54f4a-Paper.pdf

    Args:
        num_arms: number of arms.
        X_dim: dimension of the raw context (before adding the intercept).
        delta: failure probability of the confidence ellipsoid.
        seed: forwarded to ``BanditAlgorithm.__init__``.
    """

    def __init__(self, num_arms, X_dim, delta=0.1, seed=None):
        super(LinUCBAlg, self).__init__(num_arms, seed)
        self.X_dim = X_dim
        self.delta = delta
        self.subgaussian_sigma = 0.5
        self.total_dim = X_dim + 1  # +1 for the intercept feature

        # b_a: running sum of reward * x, one row per arm.
        self.arm_sums = np.zeros((num_arms, self.total_dim))
        # V_a = I + sum x x^T, one matrix per arm.
        self.arm_invcov = np.stack([np.eye(self.total_dim) for _ in range(num_arms)])

    def update_algorithm(self, arm, reward_torch, X_torch):
        """Add one (context, reward) observation to the pulled arm's statistics.

        Args:
            arm: index of the pulled arm.
            reward_torch: observed reward (tensor, numpy value or Python number).
            X_torch: context tensor in the same layout passed to
                ``sample_action``; row 0 of ``X_torch.squeeze(1)`` is used.
        """
        X_raw = X_torch.squeeze(1).numpy()
        # NOTE(review): uses row 0, i.e. assumes every arm receives the same
        # context (run_bandit notes one copy of the context per action).
        x = np.concatenate([[1], X_raw[0, :]])  # prepend the intercept feature
        reward = np.asarray(reward_torch)
        self.arm_sums[arm] = self.arm_sums[arm] + x * reward
        self.arm_invcov[arm] = self.arm_invcov[arm] + np.outer(x, x)

    def compute_score(self, X):
        """Return the upper confidence bound of every arm.

        Args:
            X: array whose i-th row is arm i's intercept-augmented context
                (as built by ``sample_action``).
        """
        arm_cov = np.linalg.inv(self.arm_invcov)                      # V_a^{-1}
        thetahats = np.einsum('ijk,ik->ij', arm_cov, self.arm_sums)   # ridge estimates
        means = np.einsum('ij,ij->i', thetahats, X)

        # ||x||_{V^{-1}}: extent of the confidence ellipsoid in direction x.
        # Clip tiny negative round-off before the square root.
        quad = np.einsum('ij,ijk,ik->i', X, arm_cov, X)
        width = np.sqrt(np.maximum(quad, 0.0))

        # slogdet avoids overflowing det(V) as observations accumulate;
        # V is positive definite, so the sign is always +1.
        _, logdet = np.linalg.slogdet(self.arm_invcov)
        radius = 1 + self.subgaussian_sigma * np.sqrt(
            2 * (0.5 * logdet - np.log(self.delta)))
        return means + radius * width

    def sample_action(self, X):
        """Return the index of the arm with the largest UCB score for context X."""
        X = X.squeeze(1).numpy()
        new_X = np.concatenate([np.ones((X.shape[0], 1)), X], 1)  # add an intercept term
        arm_scores = self.compute_score(new_X)
        return np.argmax(arm_scores)


def run_bandit(env, alg, T, num_round_robin=0, context=False, return_extra=False):
    """Simulate ``alg`` interacting with ``env`` for ``T`` rounds.

    For the first ``num_round_robin * env.num_arms`` rounds arms are pulled in
    order 0, 1, ..., num_arms - 1, 0, ...; afterwards ``alg.sample_action``
    picks the arm. Every observed reward, including those from round-robin
    rounds, is passed to ``alg.update_algorithm``.

    Args:
        env: object with ``num_arms``, ``generate_reward(arm, t)``,
            ``get_expected_reward(arm, t)`` and, when ``context`` is True,
            ``get_context(t)``.
        alg: object with ``sample_action`` and ``update_algorithm``; when
            ``context`` is True both also receive the round's context.
        T: number of rounds.
        num_round_robin: number of full round-robin passes before ``alg``
            takes over.
        context: fetch a context each round and pass it to ``alg``.
        return_extra: call ``sample_action(..., return_extra=True)`` and
            collect the returned extras (round-robin rounds contribute none).

    Returns:
        dict with numpy arrays 'rewards', 'expected_rewards', 'action_arms'
        (one entry per round) and, if ``return_extra``, a list 'extras'.
    """
    rewards, expected, arms_taken, extras_log = [], [], [], []
    n_arms = env.num_arms
    warmup = num_round_robin * n_arms
    X = None
    for t in range(T):
        extras = None
        if context:
            # One copy of the context per action.
            X = env.get_context(t)
        # Extra positional argument(s) forwarded to the algorithm.
        ctx_args = (X,) if context else ()

        if t < warmup:
            arm = t % n_arms
        elif return_extra:
            arm, extras = alg.sample_action(*ctx_args, return_extra=True)
        else:
            arm = alg.sample_action(*ctx_args)

        reward = env.generate_reward(arm, t)
        alg.update_algorithm(arm, reward, *ctx_args)
        rewards.append(reward)
        # Noise-free value of the chosen arm.
        expected.append(env.get_expected_reward(arm, t))
        arms_taken.append(arm)
        if extras is not None:
            extras_log.append(extras)

    result = {
        'rewards': np.array(rewards),
        'expected_rewards': np.array(expected),
        'action_arms': np.array(arms_taken),
    }
    if return_extra:
        result['extras'] = extras_log
    return result


############## DPT ##############

class DPTSequenceAlg(BanditAlgorithm):
    """Bandit policy driven by a sequence model, with one model state per arm.

    Args:
        seq_model: object providing ``init_model_states(batch_size=...)``,
            ``update_state(state, reward)`` and ``dpt_sampling(state, Z)``.
        Z_representation: optional per-arm representation passed to
            ``dpt_sampling``; when given its length must equal ``num_arms``.
        num_arms: number of arms (batch size of the model state).
    """

    def __init__(self, seq_model, Z_representation, num_arms):
        super().__init__()
        self.seq_model = seq_model
        self.Z_representation = Z_representation
        # One model state per arm along the leading (batch) dimension.
        self.curr_state = seq_model.init_model_states(batch_size=num_arms)
        if Z_representation is not None:
            assert len(Z_representation) == num_arms

    def update_algorithm(self, arm, reward):
        """Advance the chosen arm's model state with the observed reward."""
        # Index with a list so the selected state keeps a batch dim of size 1.
        state_batch = self.curr_state[[arm]]
        reward_batch = torch.tensor([reward])
        updated = self.seq_model.update_state(state_batch, reward_batch)
        self.curr_state[arm] = updated[0]

    def sample_action(self, return_extra=False):
        """Sample an arm from the softmax of the model's per-arm draws.

        Returns:
            The arm index as a Python int, or ``(arm, post_draws)`` when
            ``return_extra`` is True.
        """
        post_draws = self.seq_model.dpt_sampling(self.curr_state, self.Z_representation)
        probs = torch.softmax(post_draws.flatten(), 0)
        arm = torch.multinomial(probs, 1).item()
        return (arm, post_draws) if return_extra else arm


